import os
import random
import argparse
import yaml
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.transforms as transforms
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from scipy.optimize import linear_sum_assignment

# Domain Generalization
from datasets.imagenet import ImageNet
from datasets.imagenet_a import ImageNetA
from datasets.imagenet_r import ImageNetR
from datasets.imagenet_sketch import ImageNetSketch
from datasets.imagenet_v2 import ImageNet_V2

# Cross-Dataset
from datasets.caltech101 import Caltech101
from datasets.dtd import DescribableTextures
from datasets.eurosat import EuroSAT
from datasets.sun397 import SUN397
from datasets.fgvc import FGVCAircraft
from datasets.oxford_flowers import OxfordFlowers
from datasets.oxford_pets import OxfordPets
from datasets.ucf101 import UCF101
from datasets.stanford_cars import StanfordCars
from datasets.food101 import Food101
from datasets import build_dataset
from datasets.utils import build_data_loader

import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from utils import *
from torch.autograd import Variable
import numpy as np


# Module-level CLIP SimpleTokenizer instance (not referenced in the code shown here).
_tokenizer = _Tokenizer()
# Training-style augmentation pipeline: random resized crop to 224 px
# (scale 0.5-1.0, bicubic), random horizontal flip, tensor conversion, then
# normalization with the mean/std constants used by CLIP's own preprocessing.
# NOTE(review): the name is misspelled ("tranform") and it is not used by the
# code visible here; kept unchanged in case other modules import it.
train_tranform = transforms.Compose([
    transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
])


def get_arguments():
    """Parse this script's command-line options.

    ``--config`` is required because ``main`` cannot run without a config
    file. ``--infer_type`` defaults to ``'COS'``, the only supported method,
    so that omitting it does not propagate ``None`` into the TTA loop.

    Returns:
        argparse.Namespace with ``config``, ``datasets``, ``data_root``,
        ``backbone`` and ``infer_type`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', dest='config', required=True,
                        help='settings of Tip-Adapter in yaml format')
    parser.add_argument('--datasets', dest='datasets', type=str, required=True,
                        help="Datasets to process, separated by a slash (/). Example: I/A/V/R/S")
    parser.add_argument('--data_root', dest='data_root', type=str, default='/data1/',
                        help='Path to the datasets directory. Default is /data1/')
    parser.add_argument('--backbone', dest='backbone', type=str, choices=['RN50', 'ViT-B/16'], required=True,
                        help='CLIP model backbone to use: RN50 or ViT-B/16.')
    parser.add_argument('--infer_type', dest='infer_type', type=str, choices=['COS'], default='COS',
                        help='Use method for tta.')
    args = parser.parse_args()
    return args


def cos_infer(test_feature, pseudo_means):
    """Score ``test_feature`` against every row of ``pseudo_means`` by cosine similarity.

    The similarity is taken over dim 1, with the two inputs broadcast against
    each other (e.g. a (1, D) feature against (K, D) means yields K scores).

    Returns:
        The similarity tensor cast to float16.
    """
    similarity = F.cosine_similarity(test_feature, pseudo_means, dim=1)
    return similarity.to(torch.float16)


def logits_fix(logits, miss_classes):
    """Re-insert zero scores for classes that have no pseudo mean.

    ``logits`` holds one score per *present* class along its last dimension,
    ordered by ascending class index (presumably the row order of the pseudo
    means; confirm with the caller). A zero is inserted at every index in
    ``miss_classes``, so that position ``c`` of the result corresponds to
    class ``c``.

    Args:
        logits: score tensor; insertion happens along the last dimension.
        miss_classes: iterable of class indices to fill with 0, in any order.

    Returns:
        A tensor with ``len(miss_classes)`` extra entries along the last
        dimension, on the same device and with the same dtype as ``logits``.
    """
    fixed = logits
    # Each insertion shifts every later position by one, so only ascending
    # order puts every zero exactly at its own class index.
    for miss_class in sorted(miss_classes):
        zero = fixed.new_zeros(tuple(fixed.shape[:-1]) + (1,))
        fixed = torch.cat((fixed[..., :miss_class], zero, fixed[..., miss_class:]), dim=-1)
    return fixed


def get_missed_classes(clip_weights, mus_labels):
    """Return the class indices that do not occur in ``mus_labels``.

    Args:
        clip_weights: classifier weights; the size of the last dimension is
            taken as the number of classes.
        mus_labels: tensor of present class labels (presumably the label of
            each pseudo mean; confirm against the caller).

    Returns:
        Ascending list of ints in ``range(clip_weights.shape[-1])`` that are
        absent from ``mus_labels``. The ascending order matters because
        ``logits_fix`` inserts zeros one position at a time; converting a set
        difference to a list does not guarantee that order.
    """
    present = set(mus_labels.tolist())
    return [cls for cls in range(clip_weights.shape[-1]) if cls not in present]


def hyper_parm_search_run_tta(pseudo_means, loader, clip_model, clip_weights, test_features_sup,
                              pseudo_labels, cache=None, infer_type='COS', alpha=10., cfg=None):
    """Run ``run_tta`` for one or more fusion weights and report the best.

    Args:
        pseudo_means, loader, clip_model, clip_weights, test_features_sup,
        pseudo_labels, cache, infer_type, cfg: forwarded unchanged to ``run_tta``.
        alpha: a single weight, or a list/tuple of candidate weights to try.

    Returns:
        ``(best_acc, best_alpha)``: the highest accuracy returned by
        ``run_tta`` and the alpha that produced it.

    Raises:
        ValueError: if ``alpha`` is an empty list/tuple.
    """
    candidates = list(alpha) if isinstance(alpha, (list, tuple)) else [alpha]
    if not candidates:
        raise ValueError("alpha must contain at least one candidate value.")

    best_val_acc, best_alpha = None, candidates[0]
    for candidate in candidates:
        acc = run_tta(pseudo_means, loader, clip_model, clip_weights, test_features_sup,
                      pseudo_labels, cache, infer_type=infer_type, alpha=candidate, cfg=cfg)
        if best_val_acc is None or acc > best_val_acc:
            best_val_acc, best_alpha = acc, candidate

    print("Best Acc: {:.2f}".format(best_val_acc))
    print("Best Alpha: {:.4f}".format(best_alpha))
    return best_val_acc, best_alpha

def run_tta(pseudo_means, loader, clip_model, clip_weights, test_features_sup, pseudo_labels,
            cache=None, infer_type='COS', alpha=10., cfg=None, miss_classes=None):
    """Evaluate CLIP logits fused with cosine scores against pseudo class means.

    For each ``(images, target)`` batch from ``loader``, the CLIP logits from
    ``get_clip_logits`` are combined as
    ``clip_logits + alpha * cos(image_features, pseudo_means)`` and scored
    with ``cls_acc``.

    Args:
        pseudo_means: mean features, one per row; passed to ``cos_infer``.
        loader: iterable of ``(images, target)`` batches.
        clip_model, clip_weights: forwarded to ``get_clip_logits``.
        test_features_sup, cache, cfg: accepted for interface compatibility;
            not used here.
        pseudo_labels: class labels of the pseudo means (presumably one per
            row of ``pseudo_means``; confirm). Used to find classes without a
            mean when ``miss_classes`` is not given.
        infer_type: fusion method; only ``'COS'`` is implemented.
        alpha: weight of the cosine-similarity term.
        miss_classes: optional class indices that have no pseudo mean;
            derived from ``pseudo_labels`` when omitted.

    Returns:
        Mean of the per-batch ``cls_acc`` values.

    Raises:
        ValueError: for an unsupported ``infer_type`` or an empty ``loader``.
    """
    if infer_type != 'COS':
        # Fail fast: no other fusion is implemented, so final_logits would be undefined.
        raise ValueError("Unsupported infer_type {!r}; only 'COS' is implemented.".format(infer_type))
    if miss_classes is None:
        miss_classes = (get_missed_classes(clip_weights, pseudo_labels)
                        if pseudo_labels is not None else [])
    # Zero-insertion must happen in ascending class order.
    miss_classes = sorted(miss_classes)

    print("---- Current SOBA's Alpha: {:.4f}. ----\n".format(alpha))
    accuracies = []
    with torch.no_grad():
        # Test-time adaptation
        for i, (images, target) in enumerate(tqdm(loader, desc='Processed test images: ')):
            image_features, clip_logits, _, _ = get_clip_logits(images, clip_model, clip_weights)
            target = target.cuda()

            cos_logits = cos_infer(image_features, pseudo_means)
            if miss_classes:
                # Pad so that index c of cos_logits lines up with class c of clip_logits.
                cos_logits = logits_fix(cos_logits, miss_classes)
            final_logits = clip_logits.float() + cos_logits.float() * alpha

            acc = cls_acc(final_logits, target)
            accuracies.append(acc)
            if i % 1000 == 0 and i != 0:
                print("---- SOBA's test accuracy: {:.2f}. ----\n".format(sum(accuracies) / len(accuracies)))

    if not accuracies:
        raise ValueError("loader yielded no batches; cannot compute accuracy.")
    mean_acc = sum(accuracies) / len(accuracies)
    print("---- SOBA's final test accuracy: {:.2f}. ----\n".format(mean_acc))
    return mean_acc


def _build_pseudo_means(features, clip_weights):
    """Average ``features`` per class predicted by zero-shot CLIP.

    NOTE(review): the original ``main`` used ``pseudo_means`` and
    ``pseudo_labels`` without defining them. This reconstruction assumes
    ``features`` is (N, D) and ``clip_weights`` is (D, C) (``get_missed_classes``
    reads the class count from its last dimension). It also assumes pseudo
    labels come from zero-shot CLIP predictions rather than dataset labels.
    Confirm against the intended method.

    Returns:
        ``(means, labels)``. ``means`` is (K, D) in ``features``' dtype, one
        row per predicted class. ``labels`` holds those K class indices in
        ascending order, so row k of ``means`` belongs to class ``labels[k]``.
    """
    feats = features.float()
    pred = (feats @ clip_weights.float().to(feats.device)).argmax(dim=-1)
    labels = torch.unique(pred)  # sorted ascending by default
    means = torch.stack([feats[pred == cls].mean(dim=0) for cls in labels])
    return means.to(features.dtype), labels


def main():
    """Run SOBA test-time adaptation on every dataset listed in ``--datasets``.

    Raises:
        FileNotFoundError: if ``--config`` is missing or does not exist.
    """
    # Load config file
    args = get_arguments()
    if not args.config or not os.path.exists(args.config):
        raise FileNotFoundError("Config file not found: {}".format(args.config))
    # 'COS' is the only implemented method; use it if --infer_type was omitted.
    infer_type = args.infer_type or 'COS'
    # NOTE(review): yaml.Loader can instantiate arbitrary Python objects; only
    # point --config at trusted files (or switch to yaml.safe_load).
    with open(args.config, 'r') as config_file:
        cfg = yaml.load(config_file, Loader=yaml.Loader)

    # CLIP
    clip_model, preprocess = clip.load(args.backbone)
    clip_model = clip_model.cuda()
    clip_model.eval()
    for p in clip_model.parameters():
        p.requires_grad = False

    # --datasets is documented as slash-separated (e.g. I/A/V/R/S).
    for dataset_name in filter(None, args.datasets.split('/')):
        print("Preparing dataset.")
        test_loader, classnames, template, test_loader_sup = build_test_data_loader(
            dataset_name, args.data_root, preprocess)
        clip_weights = clip_classifier(classnames, template, clip_model)

        # Get SOBA features for the support loader.
        # NOTE(review): test_labels_sup is not used; pseudo labels come from CLIP below.
        test_features_sup, test_labels_sup, cache = pre_load_features(
            cfg['cache'], clip_model, test_loader_sup, clip_weights, need_cache=True)

        pseudo_means, pseudo_labels = _build_pseudo_means(test_features_sup, clip_weights)

        # NOTE(review): an earlier version passed alpha=0.0001, but the search
        # ran with alpha=10 regardless; 10. keeps the effective weight unchanged.
        hyper_parm_search_run_tta(pseudo_means, test_loader, clip_model,
                                  clip_weights, test_features_sup, pseudo_labels,
                                  cache=cache, infer_type=infer_type, alpha=10., cfg=cfg)



# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()